#include "platform/opengl/opengl_shader.hpp"

namespace Xen {

OpenGLShader::OpenGLShader(const std::string& name, const std::string& vertex_source, const std::string& fragment_source):name_(name) {
    GLuint vertexOpenGLShader = glCreateShader(GL_VERTEX_SHADER);

    const GLchar* source = (const GLchar *)vertex_source.c_str();
    GLCall(glShaderSource(vertexOpenGLShader, 1, &source, nullptr));

    glCompileShader(vertexOpenGLShader);

    GLint isCompiled = 0;
    glGetShaderiv(vertexOpenGLShader, GL_COMPILE_STATUS, &isCompiled);
    if(isCompiled == GL_FALSE) {
        GLint maxLength = 0;
        glGetShaderiv(vertexOpenGLShader, GL_INFO_LOG_LENGTH, &maxLength);

        std::vector<GLchar> infoLog(maxLength);
        glGetShaderInfoLog(vertexOpenGLShader, maxLength, &maxLength, &infoLog[0]);
        GLCall(glDeleteShader(vertexOpenGLShader));

        XEN_CORE_ERROR("{}", infoLog.data());
        XEN_CORE_ASSERT(false, "vertex shader create failed");
    }

    GLuint fragmentOpenGLShader = glCreateShader(GL_FRAGMENT_SHADER);

    source = (const GLchar *)fragment_source.c_str();
    GLCall(glShaderSource(fragmentOpenGLShader, 1, &source, 0));

    glCompileShader(fragmentOpenGLShader);

    glGetShaderiv(fragmentOpenGLShader, GL_COMPILE_STATUS, &isCompiled);
    if (isCompiled == GL_FALSE) {
        GLint maxLength = 0;
        glGetShaderiv(fragmentOpenGLShader, GL_INFO_LOG_LENGTH, &maxLength);

        std::vector<GLchar> infoLog(maxLength);
        glGetShaderInfoLog(fragmentOpenGLShader, maxLength, &maxLength, &infoLog[0]);

        GLCall(glDeleteShader(fragmentOpenGLShader));
        GLCall(glDeleteShader(vertexOpenGLShader));

        XEN_CORE_ERROR("{}", infoLog.data());
        XEN_CORE_ASSERT(false, "fragment shader create failed");
    }

    renderer_id_ = glCreateProgram();

    GLCall(glAttachShader(renderer_id_, vertexOpenGLShader));
    GLCall(glAttachShader(renderer_id_, fragmentOpenGLShader));

    glLinkProgram(renderer_id_);

    GLint isLinked = 0;
    glGetProgramiv(renderer_id_, GL_LINK_STATUS, (int *)&isLinked);
    if (isLinked == GL_FALSE) {
        GLint maxLength = 0;
        glGetProgramiv(renderer_id_, GL_INFO_LOG_LENGTH, &maxLength);

        std::vector<GLchar> infoLog(maxLength);
        glGetProgramInfoLog(renderer_id_, maxLength, &maxLength, &infoLog[0]);

        GLCall(glDeleteProgram(renderer_id_));

        GLCall(glDeleteShader(vertexOpenGLShader));
        GLCall(glDeleteShader(fragmentOpenGLShader));

        XEN_CORE_ERROR("{}", infoLog.data());
        XEN_CORE_ASSERT(false, "shader link failed");
    }

    GLCall(glDetachShader(renderer_id_, vertexOpenGLShader));
    GLCall(glDetachShader(renderer_id_, fragmentOpenGLShader));
}

/// Releases the GL program object referenced by renderer_id_.
/// Per the OpenGL spec, glDeleteProgram silently ignores the name 0.
/// NOTE(review): the program is deleted unconditionally. If the class declaration
/// (not visible here) does not disable copying, two copies would delete the same
/// program name. Confirm copy/move handling in the header.
OpenGLShader::~OpenGLShader() {
    GLCall(glDeleteProgram(renderer_id_));
}

/// Makes this shader's program current with glUseProgram(renderer_id_).
void OpenGLShader::Bind() const {
    GLCall(glUseProgram(renderer_id_));
}

/// Binds program 0. This runs unconditionally, so it clears whatever program is
/// current, not only this one.
void OpenGLShader::Unbind() const {
    GLCall(glUseProgram(0));
}

/// Uploads one 4x4 float matrix to uniform `name`, with the transpose flag false.
/// NOTE(review): glUniform* writes to the program currently in use, so callers
/// presumably Bind() first.
void OpenGLShader::UniformMat4(const std::string& name, const glm::mat4& mat) {
    const GLint loc = getUniformLocation(name);
    GLCall(glUniformMatrix4fv(loc, 1, false, glm::value_ptr(mat)));
}

void OpenGLShader::UniformInt(const std::string& name, int value) {
    auto loc = getUniformLocation(name);
    GLCall(glUniform1i(loc, value));
}

/// Uploads one vec4 to uniform `name` via glUniform4fv.
void OpenGLShader::UniformFloat4(const std::string &name, const glm::vec4 &value) {
    const GLint loc = getUniformLocation(name);
    GLCall(glUniform4fv(loc, 1, glm::value_ptr(value)));
}

void OpenGLShader::UniformFloat(const std::string &name, float value) {
    auto loc = getUniformLocation(name);
    GLCall(glUniform1f(loc, value));
}

/// Uploads one 3x3 float matrix to uniform `name`, with the transpose flag false.
void OpenGLShader::UniformMat3(const std::string &name, const glm::mat3 &mat3) {
    const GLint loc = getUniformLocation(name);
    GLCall(glUniformMatrix3fv(loc, 1, false, glm::value_ptr(mat3)));
}

/// Uploads one ivec2 to uniform `name` via glUniform2iv.
void OpenGLShader::UniformInt2(const std::string &name, const glm::ivec2& value) {
    const GLint loc = getUniformLocation(name);
    GLCall(glUniform2iv(loc, 1, glm::value_ptr(value)));
}

/// Uploads one ivec3 to uniform `name` via glUniform3iv.
void OpenGLShader::UniformInt3(const std::string &name, const glm::ivec3& value) {
    const GLint loc = getUniformLocation(name);
    GLCall(glUniform3iv(loc, 1, glm::value_ptr(value)));
}

/// Uploads one ivec4 to uniform `name` via glUniform4iv.
void OpenGLShader::UniformInt4(const std::string &name, const glm::ivec4& value) {
    const GLint loc = getUniformLocation(name);
    GLCall(glUniform4iv(loc, 1, glm::value_ptr(value)));
}

/// Uploads one vec2 to uniform `name` via glUniform2fv.
void OpenGLShader::UniformFloat2(const std::string &name, const glm::vec2 &vec2) {
    const GLint loc = getUniformLocation(name);
    GLCall(glUniform2fv(loc, 1, glm::value_ptr(vec2)));
}

/// Uploads one vec3 to uniform `name` via glUniform3fv.
void OpenGLShader::UniformFloat3(const std::string &name, const glm::vec3 &vec3) {
    const GLint loc = getUniformLocation(name);
    GLCall(glUniform3fv(loc, 1, glm::value_ptr(vec3)));
}

/// Looks up the location of uniform `name` in this shader's program.
/// Fires XEN_CORE_ASSERT when GL reports -1. The -1 is still returned, so builds
/// where the assert is compiled out pass it through to glUniform*.
/// NOTE(review): GL also returns -1 for uniforms the GLSL compiler optimized away,
/// not only for misspelled names.
GLint OpenGLShader::getUniformLocation(const std::string& name) {
    const GLint location = glGetUniformLocation(renderer_id_, name.c_str());
    const bool not_found = (location == -1);
    if (not_found) {
        XEN_CORE_ASSERT(false, "can't get location " + name);
    }
    return location;
}

}
